{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vkdnLiKk71g-"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "0asMuNro71hA"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iPFgLeZIsZ3Q"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/federated/tutorials/building_your_own_federated_learning_algorithm\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/federated/blob/master/docs/tutorials/building_your_own_federated_learning_algorithm.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/federated/blob/master/docs/tutorials/building_your_own_federated_learning_algorithm.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/federated/docs/tutorials/building_your_own_federated_learning_algorithm.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MnUwFbCAKB2r"
      },
      "source": [
        "## Before we start\n",
        "\n",
        "Please run the following to make sure that your environment is\n",
        "correctly set up. If the cells below fail to run, please refer to the\n",
        "[Installation](../install.md) guide for instructions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZrGitA_KnRO0"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install --quiet --upgrade tensorflow-federated-nightly\n",
        "!pip install --quiet --upgrade nest-asyncio\n",
        "\n",
        "import nest_asyncio\n",
        "nest_asyncio.apply()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HGTM6tWOLo8M"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow_federated as tff"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yr3ztf28fa1F"
      },
      "source": [
        "**NOTE**: This colab has been verified to work with the [latest released version](https://github.com/tensorflow/federated#compatibility) of the `tensorflow_federated` pip package, but the TensorFlow Federated project is still in pre-release development and may not work on `master`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4Zv28F7QLo8O"
      },
      "source": [
        "# Building Your Own Federated Learning Algorithm\n",
        "\n",
        "In the [image classification](federated_learning_for_image_classification.ipynb) and\n",
        "[text generation](federated_learning_for_text_generation.ipynb) tutorials, we learned how to set up model and data pipelines for Federated Learning (FL), and performed federated training via the `tff.learning` API layer of TFF.\n",
        "\n",
        "This is only the tip of the iceberg when it comes to FL research. In this tutorial, we discuss how to implement federated learning algorithms *without* deferring to the `tff.learning` API. We aim to accomplish the following:\n",
        "\n",
        "**Goals:**\n",
        "\n",
        "\n",
        "*   Understand the general structure of federated learning algorithms.\n",
        "*   Explore the *Federated Core* of TFF.\n",
        "*   Use the Federated Core to implement Federated Averaging directly.\n",
        "\n",
        "While this tutorial is self-contained, we recommend first reading the [image classification](federated_learning_for_image_classification.ipynb) and\n",
        "[text generation](federated_learning_for_text_generation.ipynb) tutorials.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hQ_N9XbULo8P"
      },
      "source": [
        "## Preparing the input data\n",
        "We first load and preprocess the EMNIST dataset included in TFF. For more details, see the [image classification](federated_learning_for_image_classification.ipynb) tutorial."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-WdnFluLLo8P"
      },
      "outputs": [],
      "source": [
        "emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kq8893GogB8E"
      },
      "source": [
        "In order to feed the dataset into our model, we flatten the data, and convert each example into a tuple of the form `(flattened_image_vector, label)`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Blrh8zJgLo8R"
      },
      "outputs": [],
      "source": [
        "NUM_CLIENTS = 10\n",
        "BATCH_SIZE = 20\n",
        "\n",
        "def preprocess(dataset):\n",
        "\n",
        "  def batch_format_fn(element):\n",
        "    \"\"\"Flatten a batch of EMNIST data and return a (features, label) tuple.\"\"\"\n",
        "    return (tf.reshape(element['pixels'], [-1, 784]), \n",
        "            tf.reshape(element['label'], [-1, 1]))\n",
        "\n",
        "  return dataset.batch(BATCH_SIZE).map(batch_format_fn)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Piy8EzqmgNqV"
      },
      "source": [
        "We now select a small number of clients, and apply the preprocessing above to their datasets."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-vYM_IT7Lo8W"
      },
      "outputs": [],
      "source": [
        "client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]\n",
        "federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))\n",
        "  for x in client_ids\n",
        "]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gNO_Y9j_Lo8X"
      },
      "source": [
        "## Preparing the model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LJ0I89ixz8yV"
      },
      "source": [
        "We use the same model as in the [image classification](federated_learning_for_image_classification.ipynb) tutorial. This model (implemented via `tf.keras`) has a single dense layer, followed by a softmax layer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Yfld4oFNLo8Y"
      },
      "outputs": [],
      "source": [
        "def create_keras_model():\n",
        "  initializer = tf.keras.initializers.GlorotNormal(seed=0)\n",
        "  return tf.keras.models.Sequential([\n",
        "      tf.keras.layers.Input(shape=(784,)),\n",
        "      tf.keras.layers.Dense(10, kernel_initializer=initializer),\n",
        "      tf.keras.layers.Softmax(),\n",
        "  ])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vLln0Q8G0Bky"
      },
      "source": [
        "In order to use this model in TFF, we wrap the Keras model as a [`tff.learning.Model`](https://www.tensorflow.org/federated/api_docs/python/tff/learning/Model). This allows us to perform the model's [forward pass](https://www.tensorflow.org/federated/api_docs/python/tff/learning/Model#forward_pass) within TFF, and [extract model outputs](https://www.tensorflow.org/federated/api_docs/python/tff/learning/Model#report_local_outputs). For more details, also see the [image classification](federated_learning_for_image_classification.ipynb) tutorial."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SPwbipTNLo8a"
      },
      "outputs": [],
      "source": [
        "def model_fn():\n",
        "  keras_model = create_keras_model()\n",
        "  return tff.learning.from_keras_model(\n",
        "      keras_model,\n",
        "      input_spec=federated_train_data[0].element_spec,\n",
        "      loss=tf.keras.losses.SparseCategoricalCrossentropy(),\n",
        "      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pCxa44rFiere"
      },
      "source": [
        "While we used `tf.keras` to create a `tff.learning.Model`, TFF supports much more general models. These models have the following relevant attributes capturing the model weights:\n",
        "\n",
        "*   `trainable_variables`: An iterable of the tensors corresponding to trainable layers.\n",
        "*   `non_trainable_variables`: An iterable of the tensors corresponding to non-trainable layers.\n",
        "\n",
        "For our purposes, we will only use the `trainable_variables` (as our model only has those!)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fPOWP2JjsfTk"
      },
      "source": [
        "# Building your own Federated Learning algorithm\n",
        "\n",
        "While the `tff.learning` API allows one to create many variants of Federated Averaging, there are other federated algorithms that do not fit neatly into this framework. For example, you may want to add regularization or clipping, or implement more complicated algorithms such as [federated GAN training](https://github.com/tensorflow/federated/tree/master/tensorflow_federated/python/research/gans). You may instead be interested in [federated analytics](https://ai.googleblog.com/2020/05/federated-analytics-collaborative-data.html)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "50N36Zz8qyY-"
      },
      "source": [
        "For these more advanced algorithms, we'll have to write our own custom algorithm using TFF. In many cases, federated algorithms have 4 main components:\n",
        "\n",
        "1. A server-to-client broadcast step.\n",
        "2. A local client update step.\n",
        "3. A client-to-server upload step.\n",
        "4. A server update step."
      ]
    },
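    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the four components concrete, here is a minimal, framework-free sketch of one communication round on a single scalar weight. It is only an illustration under simplifying assumptions: the `toy_*` names are hypothetical and not part of TFF, each client holds a short list of floats, and the local loss is assumed to be `(w - x)^2 / 2`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical toy helpers (not part of TFF) illustrating the four components.\n",
        "\n",
        "def toy_broadcast(server_weight, num_clients):\n",
        "  # 1. Server-to-client broadcast: every client receives a copy.\n",
        "  return [server_weight] * num_clients\n",
        "\n",
        "def toy_client_update(weight, local_data, learning_rate=0.1):\n",
        "  # 2. Local client update: one gradient step per example on (w - x)^2 / 2.\n",
        "  for x in local_data:\n",
        "    weight = weight - learning_rate * (weight - x)\n",
        "  return weight\n",
        "\n",
        "def toy_mean(client_weights):\n",
        "  # 3. Client-to-server upload, aggregated here as an unweighted mean.\n",
        "  return sum(client_weights) / len(client_weights)\n",
        "\n",
        "def toy_server_update(mean_client_weight):\n",
        "  # 4. Server update: replace the server weight with the mean.\n",
        "  return mean_client_weight\n",
        "\n",
        "def toy_next_fn(server_weight, client_datasets):\n",
        "  weights_at_clients = toy_broadcast(server_weight, len(client_datasets))\n",
        "  client_weights = [\n",
        "      toy_client_update(w, data)\n",
        "      for w, data in zip(weights_at_clients, client_datasets)\n",
        "  ]\n",
        "  return toy_server_update(toy_mean(client_weights))\n",
        "\n",
        "# Run a few rounds with two clients whose data are centered at 1.0 and 3.0.\n",
        "toy_weight = 0.0\n",
        "for _ in range(5):\n",
        "  toy_weight = toy_next_fn(toy_weight, [[1.0, 1.0], [3.0, 3.0]])\n",
        "print(toy_weight)"
      ]
    },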
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jH8s0GRdQt3b"
      },
      "source": [
        "In TFF, we generally represent federated algorithms as a [`tff.templates.IterativeProcess`](https://www.tensorflow.org/federated/api_docs/python/tff/templates/IterativeProcess) (which we refer to as just an `IterativeProcess` throughout). This is a class that contains `initialize` and `next` functions. Here, `initialize` is used to initialize the server, and `next` will perform one communication round of the federated algorithm. Let's write a skeleton of what our iterative process for FedAvg should look like.\n",
        "\n",
        "First, we have an initialize function that simply creates a `tff.learning.Model`, and returns its trainable weights."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ylLpRa7T5DDh"
      },
      "outputs": [],
      "source": [
        "def initialize_fn():\n",
        "  model = model_fn()\n",
        "  return model.trainable_variables"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nb1-XAK8fB2A"
      },
      "source": [
        "This function looks good, but as we will see later, we will need to make a small modification to make it a \"TFF computation\".\n",
        "\n",
        "We also want to sketch the `next_fn`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IeHN-XLZfMso"
      },
      "outputs": [],
      "source": [
        "def next_fn(server_weights, federated_dataset):\n",
        "  # Broadcast the server weights to the clients.\n",
        "  server_weights_at_client = broadcast(server_weights)\n",
        "\n",
        "  # Each client computes their updated weights.\n",
        "  client_weights = client_update(federated_dataset, server_weights_at_client)\n",
        "\n",
        "  # The server averages these updates.\n",
        "  mean_client_weights = mean(client_weights)\n",
        "\n",
        "  # The server updates its model.\n",
        "  server_weights = server_update(mean_client_weights)\n",
        "\n",
        "  return server_weights"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uWXvjXPWeujU"
      },
      "source": [
        "We'll focus on implementing these four components separately. We first focus on the parts that can be implemented in pure TensorFlow, namely the client and server update steps.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3yKS4VkALo8g"
      },
      "source": [
        "## TensorFlow Blocks "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bxpNYucgLo8g"
      },
      "source": [
        "### Client update\n",
        "\n",
        "We will use our `tff.learning.Model` to do client training in essentially the same way you would train a TensorFlow model. In particular, we will use `tf.GradientTape` to compute the gradient on batches of data, then apply these gradients using a `client_optimizer`. We focus only on the trainable weights.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c5rHPKreLo8g"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def client_update(model, dataset, server_weights, client_optimizer):\n",
        "  \"\"\"Performs training (using the server model weights) on the client's dataset.\"\"\"\n",
        "  # Get the trainable variables of the client model.\n",
        "  client_weights = model.trainable_variables\n",
        "  # Assign the server weights to the client model.\n",
        "  tf.nest.map_structure(lambda x, y: x.assign(y),\n",
        "                        client_weights, server_weights)\n",
        "\n",
        "  # Use the client_optimizer to update the local model.\n",
        "  for batch in dataset:\n",
        "    with tf.GradientTape() as tape:\n",
        "      # Compute a forward pass on the batch of data\n",
        "      outputs = model.forward_pass(batch)\n",
        "\n",
        "    # Compute the corresponding gradient\n",
        "    grads = tape.gradient(outputs.loss, client_weights)\n",
        "    grads_and_vars = zip(grads, client_weights)\n",
        "\n",
        "    # Apply the gradient using a client optimizer.\n",
        "    client_optimizer.apply_gradients(grads_and_vars)\n",
        "\n",
        "  return client_weights"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pP0D9XtoLo8i"
      },
      "source": [
        "### Server Update\n",
        "\n",
        "The server update for FedAvg is simpler than the client update. We will implement \"vanilla\" federated averaging, in which we simply replace the server model weights by the average of the client model weights. Again, we only focus on the trainable weights."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rYxErLvHLo8i"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def server_update(model, mean_client_weights):\n",
        "  \"\"\"Updates the server model weights as the average of the client model weights.\"\"\"\n",
        "  model_weights = model.trainable_variables\n",
        "  # Assign the mean client weights to the server model.\n",
        "  tf.nest.map_structure(lambda x, y: x.assign(y),\n",
        "                        model_weights, mean_client_weights)\n",
        "  return model_weights"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ddCklfWlVr1U"
      },
      "source": [
        "The snippet above could be simplified by directly returning the `mean_client_weights`. However, more advanced implementations of Federated Averaging use `mean_client_weights` with more sophisticated techniques, such as momentum or adaptivity.\n",
        "\n",
        "**Challenge**: Implement a version of `server_update` that updates the server weights to be the midpoint of `model_weights` and `mean_client_weights`. (Note: This kind of \"midpoint\" approach is analogous to recent work on the [Lookahead optimizer](https://arxiv.org/abs/1907.08610)!)"
      ]
    },
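    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One possible way to approach the challenge is sketched below. This is not an official solution: the name `midpoint_server_update` is hypothetical, and the sketch assumes that `mean_client_weights` has the same nested structure and dtypes as `model.trainable_variables`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "@tf.function\n",
        "def midpoint_server_update(model, mean_client_weights):\n",
        "  \"\"\"Moves the server weights halfway toward the mean client weights.\"\"\"\n",
        "  model_weights = model.trainable_variables\n",
        "  # Assign the elementwise midpoint of the current and averaged weights.\n",
        "  tf.nest.map_structure(lambda v, m: v.assign(0.5 * (v + m)),\n",
        "                        model_weights, mean_client_weights)\n",
        "  return model_weights"
      ]
    },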
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KuP9g6RFLo8k"
      },
      "source": [
        "So far, we've only written pure TensorFlow code. This is by design, as TFF allows you to use much of the TensorFlow code you're already familiar with. However, now we have to specify the **orchestration logic**, that is, the logic that dictates what the server broadcasts to the client, and what the client uploads to the server.\n",
        "\n",
        "This will require the *Federated Core* of TFF."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0CgFLVPgLo8l"
      },
      "source": [
        "# Introduction to the Federated Core\n",
        "\n",
        "The Federated Core (FC) is a set of lower-level interfaces that serve as the foundation for the `tff.learning` API. However, these interfaces are not limited to learning. In fact, they can be used for analytics and many other computations over distributed data.\n",
        "\n",
        "At a high level, the Federated Core is a development environment that enables compactly expressed program logic to combine TensorFlow code with distributed communication operators (such as distributed sums and broadcasts). The goal is to give researchers and practitioners explicit control over the distributed communication in their systems, without requiring system implementation details (such as specifying point-to-point network message exchanges).\n",
        "\n",
        "One key point is that TFF is designed for privacy-preservation. Therefore, it allows explicit control over where data resides, to prevent unwanted accumulation of data at the centralized server location."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EYinjNqZLo8l"
      },
      "source": [
        "## Federated data\n",
        "\n",
        "A key concept in TFF is \"federated data\", which refers to a collection of data items hosted across a group of devices in a distributed system (e.g., client datasets, or the server model weights). We model the entire collection of data items across all devices as a single *federated value*.\n",
        "\n",
        "For example, suppose we have client devices that each have a float representing the temperature of a sensor. We could represent these values as a *federated float* as follows:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7EJY0MHpLo8l"
      },
      "outputs": [],
      "source": [
        "federated_float_on_clients = tff.FederatedType(tf.float32, tff.CLIENTS)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JSQAXD0FLo8n"
      },
      "source": [
        "A federated type is specified by a type `T` of its member constituents (e.g., `tf.float32`) and a group `G` of devices. We will focus on the cases where `G` is either `tff.CLIENTS` or `tff.SERVER`. Such a federated type is represented as `{T}@G`, as shown below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6mlPgubJLo8n"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'{float32}@CLIENTS'"
            ]
          },
          "execution_count": 12,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "str(federated_float_on_clients)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pjAQytkeLo8o"
      },
      "source": [
        "Why do we care so much about placements? A key goal of TFF is to enable writing code that could be deployed on a real distributed system. This means that it is vital to reason about which subsets of devices execute which code, and where different pieces of data reside.\n",
        "\n",
        "TFF focuses on three things: *data*, where the data is *placed*, and how the data is being *transformed*. The first two are encapsulated in federated types, while the last is encapsulated in *federated computations*."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZLT2FmVMLo8p"
      },
      "source": [
        "## Federated computations"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-XwDC1vTLo8p"
      },
      "source": [
        "TFF is a strongly-typed functional programming environment whose basic units are *federated computations*. These are pieces of logic that accept federated values as input, and return federated values as output.\n",
        "\n",
        "For example, suppose we wanted to average the temperatures on our client sensors. We could define the following (using our federated float):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IfwXDNR1Lo8p"
      },
      "outputs": [],
      "source": [
        "@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))\n",
        "def get_average_temperature(client_temperatures):\n",
        "  return tff.federated_mean(client_temperatures)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iSgs6Te5Lo8r"
      },
      "source": [
        "You might ask, how is this different from the `tf.function` decorator in TensorFlow? The key answer is that the code generated by `tff.federated_computation` is neither TensorFlow nor Python code; it is a specification of a distributed system in an internal platform-independent *glue language*.\n",
        "\n",
        "While this may sound complicated, you can think of TFF computations as functions with well-defined type signatures. These type signatures can be directly queried."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mVq500KzG2mB"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'({float32}@CLIENTS -> float32@SERVER)'"
            ]
          },
          "execution_count": 14,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "str(get_average_temperature.type_signature)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TveOYFfuLo8s"
      },
      "source": [
        "This `tff.federated_computation` accepts arguments of federated type `{float32}@CLIENTS`, and returns values of federated type `float32@SERVER`. Federated computations may also go from server to client, from client to client, or from server to server. Federated computations can also be composed like normal functions, as long as their type signatures match up.\n",
        "\n",
        "To support development, TFF allows you to invoke a `tff.federated_computation` as a Python function. For example, we can call"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PTowUHohG2mB"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "69.53334"
            ]
          },
          "execution_count": 15,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "get_average_temperature([68.5, 70.3, 69.8])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZXn-yje9RJ6H"
      },
      "source": [
        "## Non-eager computations and TensorFlow"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nwyj8f3HLo8w"
      },
      "source": [
        "There are two key restrictions to be aware of. First, when the Python interpreter encounters a `tff.federated_computation` decorator, the function is traced once and serialized for future use. Due to the decentralized nature of Federated Learning, this future usage may occur elsewhere, such as a remote execution environment. Therefore, TFF computations are fundamentally *non-eager*. This behavior is somewhat analogous to that of the [`tf.function`](https://www.tensorflow.org/api_docs/python/tf/function) decorator in TensorFlow.\n",
        "\n",
        "Second, a federated computation can only consist of federated operators (such as `tff.federated_mean`); it cannot contain TensorFlow operations. TensorFlow code must be confined to blocks decorated with `tff.tf_computation`. Most ordinary TensorFlow code can be directly decorated, such as the following function that takes a number and adds `0.5` to it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "huz3mNmMLo8w"
      },
      "outputs": [],
      "source": [
        "@tff.tf_computation(tf.float32)\n",
        "def add_half(x):\n",
        "  return tf.add(x, 0.5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5ptjWALDLo8y"
      },
      "source": [
        "These also have type signatures, but *without placements*. For example, we can call"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "34x5H2hzG2mC"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'(float32 -> float32)'"
            ]
          },
          "execution_count": 17,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "str(add_half.type_signature)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WNjwrNMjLo8z"
      },
      "source": [
        "Here we see an important difference between `tff.federated_computation` and `tff.tf_computation`. The former has explicit placements, while the latter does not.\n",
        "\n",
        "We can use `tff.tf_computation` blocks in federated computations by specifying placements. Let's create a function that adds half, but only to federated floats at the clients. We can do this by using `tff.federated_map`, which applies a given `tff.tf_computation`, while preserving the placement."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pG6nw3wiLo80"
      },
      "outputs": [],
      "source": [
        "@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))\n",
        "def add_half_on_clients(x):\n",
        "  return tff.federated_map(add_half, x)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h4msKRKJLo81"
      },
      "source": [
        "This function is almost identical to `add_half`, except that it only accepts values with placement at `tff.CLIENTS`, and returns values with the same placement. We can see this in its type signature:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x3H-oeWIG2mC"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'({float32}@CLIENTS -> {float32}@CLIENTS)'"
            ]
          },
          "execution_count": 19,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "str(add_half_on_clients.type_signature)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3JxQ0DeiLo83"
      },
      "source": [
        "In summary:\n",
        "\n",
        "*   TFF operates on federated values.\n",
        "*   Each federated value has a *federated type*, with a *type* (e.g., `tf.float32`) and a *placement* (e.g., `tff.CLIENTS`).\n",
        "*   Federated values can be transformed using *federated computations*, which must be decorated with `tff.federated_computation` and a federated type signature.\n",
        "*   TensorFlow code must be contained in blocks with `tff.tf_computation` decorators. \n",
        "*   These blocks can then be incorporated into federated computations.\n"
      ]
    },
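    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a small illustration of the summary above, the following sketch composes the pieces we have already seen: it applies the `add_half` TensorFlow block on each client with `tff.federated_map`, then averages the results with `tff.federated_mean`. The name `add_half_then_average` is hypothetical, and the cell assumes that `add_half` has been defined by running the earlier cells. Querying the type signature should show a computation that maps client-placed floats to a server-placed float."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow_federated as tff\n",
        "\n",
        "@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))\n",
        "def add_half_then_average(client_temperatures):\n",
        "  # Apply the TensorFlow block on each client; the placement is preserved.\n",
        "  shifted = tff.federated_map(add_half, client_temperatures)\n",
        "  # Aggregate the client-placed values into a single server-placed value.\n",
        "  return tff.federated_mean(shifted)\n",
        "\n",
        "str(add_half_then_average.type_signature)"
      ]
    },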
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PvyFWox3Lo83"
      },
      "source": [
        "# Building your own Federated Learning algorithm, revisited\n",
        "\n",
        "Now that we've gotten a glimpse of the Federated Core, we can build our own federated learning algorithm. Remember that above, we defined an `initialize_fn` and `next_fn` for our algorithm. The `next_fn` will make use of the `client_update` and `server_update` we defined using pure TensorFlow code.\n",
        "\n",
        "However, in order to make our algorithm a federated computation, we will need both the `next_fn` and `initialize_fn` to each be a `tff.federated_computation`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CvY8fh1cLo84"
      },
      "source": [
        "## TensorFlow Federated blocks "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g0zNTO7LLo84"
      },
      "source": [
        "### Creating the initialization computation\n",
        "\n",
        "The initialize function will be quite simple: We will create a model using `model_fn`. However, remember that we must separate out our TensorFlow code using `tff.tf_computation`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jJY9xUBZLo84"
      },
      "outputs": [],
      "source": [
        "@tff.tf_computation\n",
        "def server_init():\n",
        "  model = model_fn()\n",
        "  return model.trainable_variables"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SGlv8LLgLo85"
      },
      "source": [
        "We can then pass this directly into a federated computation using `tff.federated_value`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m2hinzuRLo86"
      },
      "outputs": [],
      "source": [
        "@tff.federated_computation\n",
        "def initialize_fn():\n",
        "  return tff.federated_value(server_init(), tff.SERVER)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NFBghOgxLo88"
      },
      "source": [
        "### Creating the `next_fn`\n",
        "\n",
        "We now use our client and server update code to write the actual algorithm. We will first turn our `client_update` into a `tff.tf_computation` that accepts a client dataset and the server weights, and outputs the updated client weights.\n",
        "\n",
        "We will need the corresponding types to properly decorate our function. Luckily, the type of the server weights can be extracted directly from our model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ph_noHN2Lo88"
      },
      "outputs": [],
      "source": [
        "whimsy_model = model_fn()\n",
        "tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WMPgpTaW66qx"
      },
      "source": [
        "Let's look at the dataset type signature. Remember that we took 28 by 28 images (with integer labels) and flattened them."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GE2sYVA9G2mE"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'<float32[?,784],int32[?,1]>*'"
            ]
          },
          "execution_count": 23,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "str(tf_dataset_type)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kuS8d0BHLo8-"
      },
      "source": [
        "We can also extract the model weights type by using our `server_init` function above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4yx6CExMLo8-"
      },
      "outputs": [],
      "source": [
        "model_weights_type = server_init.type_signature.result"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mS-Eh6Xj7J15"
      },
      "source": [
        "Examining the type signature, we'll be able to see the architecture of our model!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_s8eFsyvG2mE"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'<float32[784,10],float32[10]>'"
            ]
          },
          "execution_count": 25,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "str(model_weights_type)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g1U1wTGRLo8_"
      },
      "source": [
        "We can now create our `tff.tf_computation` for the client update."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Q0W05pMWLo9A"
      },
      "outputs": [],
      "source": [
        "@tff.tf_computation(tf_dataset_type, model_weights_type)\n",
        "def client_update_fn(tf_dataset, server_weights):\n",
        "  model = model_fn()\n",
        "  client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)\n",
        "  return client_update(model, tf_dataset, server_weights, client_optimizer)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uP5quaAuLo9B"
      },
      "source": [
        "The `tff.tf_computation` version of the server update can be defined in a similar way, using types we've already extracted."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "F4WvQtVzLo9B"
      },
      "outputs": [],
      "source": [
        "@tff.tf_computation(model_weights_type)\n",
        "def server_update_fn(mean_client_weights):\n",
        "  model = model_fn()\n",
        "  return server_update(model, mean_client_weights)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SImhLbu4Lo9D"
      },
      "source": [
        "Last, but not least, we need to create the `tff.federated_computation` that brings this all together. This function will accept two *federated values*, one corresponding to the server weights (with placement `tff.SERVER`), and the other corresponding to the client datasets (with placement `tff.CLIENTS`).\n",
        "\n",
        "Note that both these types were defined above! We simply need to give them the proper placement using `tff.FederatedType`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ekPsA8AsLo9D"
      },
      "outputs": [],
      "source": [
        "federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)\n",
        "federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7FXAX7vGLo9G"
      },
      "source": [
        "Remember the 4 elements of an FL algorithm?\n",
        "\n",
        "1. A server-to-client broadcast step.\n",
        "2. A local client update step.\n",
        "3. A client-to-server upload step.\n",
        "4. A server update step.\n",
        "\n",
        "Now that we've built up the above, each part can be compactly represented as a single line of TFF code. This simplicity is why we had to take extra care to specify things such as federated types!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Epc7MwfELo9G"
      },
      "outputs": [],
      "source": [
        "@tff.federated_computation(federated_server_type, federated_dataset_type)\n",
        "def next_fn(server_weights, federated_dataset):\n",
        "  # Broadcast the server weights to the clients.\n",
        "  server_weights_at_client = tff.federated_broadcast(server_weights)\n",
        "\n",
        "  # Each client computes their updated weights.\n",
        "  client_weights = tff.federated_map(\n",
        "      client_update_fn, (federated_dataset, server_weights_at_client))\n",
        "  \n",
        "  # The server averages these updates.\n",
        "  mean_client_weights = tff.federated_mean(client_weights)\n",
        "\n",
        "  # The server updates its model.\n",
        "  server_weights = tff.federated_map(server_update_fn, mean_client_weights)\n",
        "\n",
        "  return server_weights"
      ]
    },
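    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the four steps concrete, here is a minimal sketch of one round written in plain NumPy, without TFF. The names `sketch_client_update` and `sketch_next`, the least-squares client objective, and the learning rate are assumptions made purely for illustration; they are not part of TFF or of the code above. The sketch also assumes that the server update simply adopts the averaged client weights."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical stand-in for the client update: one gradient step of\n",
        "# least-squares regression on the client's local data.\n",
        "def sketch_client_update(weights, client_data, learning_rate=0.1):\n",
        "  x, y = client_data\n",
        "  w, b = weights\n",
        "  err = x @ w + b - y\n",
        "  grad_w = x.T @ err / len(x)\n",
        "  grad_b = err.mean(axis=0)\n",
        "  return [w - learning_rate * grad_w, b - learning_rate * grad_b]\n",
        "\n",
        "# Hypothetical stand-in for one round of the algorithm.\n",
        "def sketch_next(server_weights, client_datasets):\n",
        "  # 1. Broadcast: every client receives its own copy of the server weights.\n",
        "  weights_at_clients = [\n",
        "      [np.copy(v) for v in server_weights] for _ in client_datasets]\n",
        "  # 2. Local update: each client trains on its own data.\n",
        "  client_weights = [\n",
        "      sketch_client_update(w, d)\n",
        "      for w, d in zip(weights_at_clients, client_datasets)]\n",
        "  # 3. Upload: the server takes the unweighted mean of each weight tensor.\n",
        "  mean_client_weights = [\n",
        "      np.mean(layer, axis=0) for layer in zip(*client_weights)]\n",
        "  # 4. Server update: assumed here to adopt the mean as the new model.\n",
        "  return mean_client_weights\n",
        "\n",
        "# Toy data: 4 clients, each with 8 examples of 3 features and 2 targets.\n",
        "sketch_rng = np.random.default_rng(0)\n",
        "sketch_clients = [\n",
        "    (sketch_rng.normal(size=(8, 3)), sketch_rng.normal(size=(8, 2)))\n",
        "    for _ in range(4)]\n",
        "sketch_weights = [np.zeros((3, 2)), np.zeros(2)]\n",
        "sketch_weights = sketch_next(sketch_weights, sketch_clients)\n",
        "print([v.shape for v in sketch_weights])"
      ]
    },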
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kWomG3TtLo9I"
      },
      "source": [
        "We now have a `tff.federated_computation` for both the algorithm initialization, and for running one step of the algorithm. To finish our algorithm, we pass these into `tff.templates.IterativeProcess`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GxdWgEddLo9I"
      },
      "outputs": [],
      "source": [
        "federated_algorithm = tff.templates.IterativeProcess(\n",
        "    initialize_fn=initialize_fn,\n",
        "    next_fn=next_fn\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7Z__9k-Dc1I3"
      },
      "source": [
        "Let's look at the *type signature* of the `initialize` and `next` functions of our iterative process."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EmyYgDNdG2mF"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'( -> <float32[784,10],float32[10]>@SERVER)'"
            ]
          },
          "execution_count": 31,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "str(federated_algorithm.initialize.type_signature)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UyyEi5Kec90_"
      },
      "source": [
        "This reflects the fact that `federated_algorithm.initialize` is a no-arg function that returns a single-layer model (with a 784-by-10 weight matrix, and 10 bias units)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZRwHwQsCG2mG"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'(<server_weights=<float32[784,10],float32[10]>@SERVER,federated_dataset={<float32[?,784],int32[?,1]>*}@CLIENTS> -> <float32[784,10],float32[10]>@SERVER)'"
            ]
          },
          "execution_count": 32,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "str(federated_algorithm.next.type_signature)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "efpdHodmdU_6"
      },
      "source": [
        "Here, we see that `federated_algorithm.next` accepts a server model and client data, and returns an updated server model."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4UYZ3qeMLo9N"
      },
      "source": [
        "## Evaluating the algorithm"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jwd9Gs0ULo9O"
      },
      "source": [
        "Let's run a few rounds, and see how the loss changes. First, we will define an evaluation function using the *centralized* approach discussed in the second tutorial.\n",
        "\n",
        "We first create a centralized evaluation dataset, and then apply the same preprocessing we used for the training data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EdNgYoIwLo9P"
      },
      "outputs": [],
      "source": [
        "central_emnist_test = emnist_test.create_tf_dataset_from_all_clients()\n",
        "central_emnist_test = preprocess(central_emnist_test)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7R50NZ35dphE"
      },
      "source": [
        "Next, we write a function that accepts a server state, and uses Keras to evaluate on the test dataset. If you're familiar with `tf.keras`, this will all look familiar, though note the use of `set_weights`!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "I5UEX4EWLo9Q"
      },
      "outputs": [],
      "source": [
        "def evaluate(server_state):\n",
        "  keras_model = create_keras_model()\n",
        "  keras_model.compile(\n",
        "      loss=tf.keras.losses.SparseCategoricalCrossentropy(),\n",
        "      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]  \n",
        "  )\n",
        "  keras_model.set_weights(server_state)\n",
        "  keras_model.evaluate(central_emnist_test)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hygoBACkLo9S"
      },
      "source": [
        "Now, let's initialize our algorithm and evaluate on the test set."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 43,
      "metadata": {
        "id": "CDarZn71G2mH"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "2042/2042 [==============================] - 2s 767us/step - loss: 2.8479 - sparse_categorical_accuracy: 0.1027\n"
          ]
        }
      ],
      "source": [
        "server_state = federated_algorithm.initialize()\n",
        "evaluate(server_state)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2knqix2cLo9U"
      },
      "source": [
        "Let's train for a few rounds and see if anything changes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 44,
      "metadata": {
        "id": "v1zBlzFILo9U"
      },
      "outputs": [],
      "source": [
        "for _ in range(15):\n",
        "  server_state = federated_algorithm.next(server_state, federated_train_data)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 45,
      "metadata": {
        "id": "2QDhaI_DG2mH"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "2042/2042 [==============================] - 2s 738us/step - loss: 2.5867 - sparse_categorical_accuracy: 0.0980\n"
          ]
        }
      ],
      "source": [
        "evaluate(server_state)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XM34ammUW-T3"
      },
      "source": [
        "We see a slight decrease in the loss (from about 2.85 to about 2.59), though the accuracy does not improve yet. While the change is small, we've only performed 15 training rounds, and on a small subset of clients. To see better results, we may have to do hundreds if not thousands of rounds."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o13H5dDFXRFn"
      },
      "source": [
        "## Modifying our algorithm"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Qt4jVD21XTL-"
      },
      "source": [
        "At this point, let's stop and think about what we've accomplished. We've implemented Federated Averaging directly by combining pure TensorFlow code (for the client and server updates) with federated computations from the Federated Core of TFF.\n",
        "\n",
        "To perform more sophisticated learning, we can simply alter what we have above. In particular, by editing the pure TF code above, we can change how the client performs training, or how the server updates its model.\n",
        "\n",
        "**Challenge:** Add [gradient clipping](https://towardsdatascience.com/what-is-gradient-clipping-b8e815cdfb48) to the `client_update` function.\n"
      ]
    },
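    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a possible starting point for the challenge, the sketch below shows gradient clipping in isolation using `tf.clip_by_global_norm`. The helper `clip_gradients` and the value `clip_norm=1.0` are hypothetical choices for illustration. Assuming `client_update` computes a list of gradients with a `tf.GradientTape` before calling the optimizer, the clipped gradients could be passed to `apply_gradients` in place of the raw ones."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "# Hypothetical helper for illustration; it is not defined elsewhere in this tutorial.\n",
        "def clip_gradients(grads, clip_norm=1.0):\n",
        "  # Rescale all gradients jointly so that their global L2 norm is at most clip_norm.\n",
        "  clipped_grads, _ = tf.clip_by_global_norm(grads, clip_norm)\n",
        "  return clipped_grads\n",
        "\n",
        "# Toy gradients with global norm 5 (a 3-4-5 triangle plus a zero entry).\n",
        "example_grads = [tf.constant([3.0, 4.0]), tf.constant([0.0])]\n",
        "clipped_grads = clip_gradients(example_grads)\n",
        "\n",
        "# The global norm after clipping should be approximately clip_norm.\n",
        "print(tf.linalg.global_norm(clipped_grads).numpy())"
      ]
    },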
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p7wvwgS7bCTy"
      },
      "source": [
        "If we wanted to make larger changes, we could also have the server store and broadcast more data. For example, the server could also store the client learning rate, and make it decay over time! Note that this will require changes to the type signatures used in the `tff.tf_computation` calls above.\n",
        "\n",
        "**Harder Challenge:** Implement Federated Averaging with learning rate decay on the clients.\n",
        "\n",
        "At this point, you may begin to realize how much flexibility there is in what you can implement in this framework. For ideas (including the answer to the harder challenge above) you can see the source code for [`tff.learning.build_federated_averaging_process`](https://www.tensorflow.org/federated/api_docs/python/tff/learning/build_federated_averaging_process), or check out various [research projects](https://github.com/google-research/federated) using TFF."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "building_your_own_federated_learning_algorithm.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
